Introduction to PyMC

IS 5150/6110

Probabilistic Inference

\[ \LARGE{p(\theta | X) \propto p(X | \theta) \ p(\theta)} \]
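To make the proportionality concrete, here is a minimal grid-approximation sketch. The data (6 successes in 9 trials), the flat prior, and the 101-point grid are all assumptions chosen for illustration, not part of the course material. It multiplies likelihood by prior at each candidate value of theta and then normalizes.

```python
import math

# Hypothetical data: 6 successes in 9 Bernoulli trials.
successes, trials = 6, 9

# Grid of candidate values for theta (the success probability).
grid = [i / 100 for i in range(101)]

# Prior p(theta): flat over the grid (an assumption made for illustration).
prior = [1.0 for _ in grid]

# Likelihood p(X | theta): binomial probability of the observed data.
likelihood = [
    math.comb(trials, successes) * t**successes * (1 - t) ** (trials - successes)
    for t in grid
]

# Unnormalized posterior: p(X | theta) * p(theta).
unnormalized = [lik * pri for lik, pri in zip(likelihood, prior)]

# Normalize so the posterior sums to one over the grid.
total = sum(unnormalized)
posterior = [u / total for u in unnormalized]

# Summaries of the grid posterior.
posterior_mode = grid[posterior.index(max(posterior))]
posterior_mean = sum(t * p for t, p in zip(grid, posterior))
print(posterior_mode, round(posterior_mean, 3))
```

Grid approximation only works for one or two parameters, because the number of grid points grows exponentially with the number of dimensions. That is the motivation for sampling-based inference.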

Probabilistic Programming

Probabilistic programming languages (PPLs) let you specify probability distributions and run inference on them directly in code. In practice, this removes most of the hand-coded computational burden of Bayesian statistics:

  1. You specify your model in terms of probability functions.
  2. The PPL automatically compiles a sampler (or approximation) for your model.

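Most samplers are some form of Markov chain Monte Carlo (MCMC), and Hamiltonian Monte Carlo (HMC) is currently the most effective general-purpose approach for models with continuous parameters. We'll be using PyMC, whose default sampler for continuous models is NUTS, an adaptive variant of HMC.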
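To show what "some form of Markov chain Monte Carlo" means mechanically, here is a minimal random-walk Metropolis sketch. This is deliberately not the HMC/NUTS algorithm that PyMC uses, only the simplest member of the same family. The target (a binomial likelihood with 6 successes in 9 trials and a flat prior), the step size, and the warm-up length are assumptions chosen for illustration.

```python
import math
import random


def log_unnormalized_posterior(theta, successes=6, trials=9):
    """Log of p(X | theta) * p(theta) for a binomial likelihood and flat prior."""
    if not 0.0 < theta < 1.0:
        return -math.inf  # Zero posterior density outside (0, 1).
    return successes * math.log(theta) + (trials - successes) * math.log(1.0 - theta)


def metropolis(n_draws, step=0.2, start=0.5, seed=42):
    """Random-walk Metropolis: propose a nearby value, accept or reject it."""
    rng = random.Random(seed)
    current = start
    current_lp = log_unnormalized_posterior(current)
    draws = []
    for _ in range(n_draws):
        # Propose a move from a normal distribution centered on the current value.
        proposal = current + rng.gauss(0.0, step)
        proposal_lp = log_unnormalized_posterior(proposal)
        # Accept with probability min(1, posterior ratio); otherwise stay put.
        accept_prob = math.exp(min(0.0, proposal_lp - current_lp))
        if rng.random() < accept_prob:
            current, current_lp = proposal, proposal_lp
        draws.append(current)
    return draws


draws = metropolis(20000)
kept = draws[2000:]  # Discard the first draws as warm-up.
posterior_mean = sum(kept) / len(kept)
print(round(posterior_mean, 2))
```

Only the ratio of posterior densities enters the accept step, so the normalizing constant is never needed. HMC replaces the blind random-walk proposal with one guided by the gradient of the log posterior, which is why it scales better to many parameters.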

import numpy as np
import polars as pl
import pymc as pm
import arviz as az

np.random.seed(42)

# Set the parameter values.
beta0 = 3
beta1 = 7
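sigma_true = 3  # Named sigma_true so the model's sigma variable below does not overwrite it.
n = 100

# Simulate data from the linear model with normal noise.
x = np.random.uniform(0, 7, size = n)
y = beta0 + beta1 * x + np.random.normal(size = n) * sigma_true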

# Create a Model object.
basic_model = pm.Model()

# Specify the model.
with basic_model:
  # Prior.
  beta = pm.Normal('beta', mu = 0, sigma = 10, shape = 2)
  sigma = pm.HalfNormal('sigma', sigma = 1)

  # Likelihood.
  mu = beta[0] + beta[1] * x
  y_obs = pm.Normal('y_obs', mu = mu, sigma = sigma, observed = y)

# Create an InferenceData object.
with basic_model:
  # Draw posterior samples (by default, 1000 draws per chain after 1000 tuning steps).
  idata = pm.sample()

# Have we recovered the parameters?
az.summary(idata, round_to = 2)

# Visualize marginal posteriors.
az.plot_trace(idata, combined = True)

# Estimate the direct causal effect of avgfood on weight.
fit <- ulam(
  alist(
    weight ~ dnorm(mu, sigma),
    mu <- beta0 + beta_food * avgfood + beta_group * groupsize,
    beta0 ~ dnorm(0, 0.2),
    c(beta_food, beta_group) ~ dnorm(0, 0.5),
    sigma ~ dexp(1)
  ), 
  data = foxes_list, # Specify the data list instead of a data frame.
  chains = 4,        # Specify the number of chains.
  cores = 4,         # Specify the number of cores to run in parallel.
  log_lik = TRUE,    # To compute model fit via WAIC and PSIS.
  cmdstan = TRUE     # Specify cmdstan = TRUE to use cmdstanr instead of rstan.
)